import torch
import torch.nn as nn
import psutil


class Catcher(nn.Module):
    """Stand-in module that records what it is called with, then aborts.

    ``forward`` never invokes the wrapped ``module``. It copies the incoming
    tensor and selected keyword arguments (moved to CPU) into the
    caller-supplied ``inps`` / ``cache`` containers and then raises
    ``ValueError`` so the surrounding forward pass stops at this point.

    Args:
        module: The module being wrapped. It is only stored, never called here.
        inps: Indexable container that receives one captured input per call.
        cache: Dict holding the running write index under ``"i"``. ``forward``
            also writes ``"catcher_attention_mask"`` and
            ``"catcher_position_ids"`` into it.
    """

    def __init__(self, module, inps, cache):
        super().__init__()
        self.module = module
        self.inps = inps
        self.cache = cache

    def forward(self, inp, **kwargs):
        """Capture ``inp`` and the mask / position kwargs, then raise.

        Args:
            inp: Tensor to store at ``inps[cache["i"]]`` (as a CPU tensor).
            **kwargs: Extra call arguments. Only ``attention_mask`` and
                ``position_ids`` are recorded; either may be absent, in which
                case ``None`` is stored for it.

        Raises:
            ValueError: Always, after the capture has been recorded.
        """
        # Store the CPU copy in the next free slot and advance the index.
        self.inps[self.cache["i"]] = inp.cpu()
        self.cache["i"] += 1

        # Move everything to cpu otherwise it will stay in cuda later on.
        # A new dict is built rather than rewriting kwargs during iteration.
        cpu_kwargs = {
            key: value.cpu() if isinstance(value, torch.Tensor) else value
            for key, value in kwargs.items()
        }

        # Both entries are optional; a missing key is recorded as None so the
        # method always finishes with the ValueError below, never a KeyError.
        self.cache["catcher_attention_mask"] = cpu_kwargs.get("attention_mask")
        self.cache["catcher_position_ids"] = cpu_kwargs.get("position_ids")

        # Deliberate early exit: the capture is done, stop the forward pass.
        raise ValueError


def print_gpu_info(device):
    """Print a snapshot of GPU memory, system RAM and CPU utilisation.

    Args:
        device: CUDA device handed straight to ``torch.cuda.mem_get_info``.
    """
    print('Checking hardware run status:')

    # GPU memory: the query yields a (free, total) pair, converted to MiB below.
    mem_info = torch.cuda.mem_get_info(device)
    free = mem_info[0]
    total = mem_info[1]
    used_bytes = total - free
    mem_used_MB = used_bytes / 1024**2
    print(f" - VRAM (GPU): [{mem_used_MB}/{total / 1024**2}]")

    # Host RAM and CPU usage, both reported by psutil as percentages.
    print(f" - RAM: [{psutil.virtual_memory().percent}/100.0]")
    print(f" - CPU: [{psutil.cpu_percent()}/100.0]")